import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from matplotlib import ticker

# Raman model function WITHOUT a constant offset term (the fixed background is subtracted from the data instead)
def raman_model_no_offset(L, beta, alpha, I0_fixed):
    """
    Evaluate the offset-free Raman generation model.

    R(L) = beta * I0_fixed * L * exp(-alpha * L)

    The result is a term growing linearly with fiber length, damped by
    exponential attenuation. No constant background term is included.

    Args:
        L (np.array): Fiber length (in km).
        beta (float): Raman generation coefficient.
        alpha (float): Fiber attenuation coefficient (in km^-1).
        I0_fixed (float): Fixed initial light intensity (in mW).

    Returns:
        np.array: Modelled Raman power (no offset term).
    """
    linear_growth = beta * I0_fixed * L
    attenuation = np.exp(-alpha * L)
    return linear_growth * attenuation


# --- Script entry point: fit and plot the measured Raman curves ---
if __name__ == "__main__":
    # Fixed background (in counts) subtracted from ALL measured counts before
    # fitting. The model has no offset term, so this value is added back when
    # the fitted curves are drawn on top of the raw data.
    FIXED_OFFSET_SUBTRACTION = 20

    # Measured data for the three curves. 'y' holds the raw counts; the offset
    # subtraction is applied only to the copy used for fitting.
    original_data_curves = {
        '13 dBm': {
            'x': np.array([50, 100, 150, 200]),  # L (km)
            'y': np.array([6850, 3260, 560, 200]),  # Raman counts
            'snr': np.array([60, 58, 44, 33]),  # SNR (dB)
            'I0_dBm': 13,  # Initial intensity for this curve
        },
        '3 dBm': {
            'x': np.array([50, 100, 150, 200]),  # L (km)
            'y': np.array([1300, 300, 80, 40]),  # Raman counts
            'snr': np.array([48, 44, 34, 0]),  # SNR (dB)
            'I0_dBm': 3,  # Initial intensity for this curve
        },
        '-7 dBm': {
            'x': np.array([50, 100, 150, 200]),  # L (km)
            'y': np.array([200, 70, 20, 30]),  # Raman counts
            'snr': np.array([43, 31, 13, 0]),  # SNR (dB)
            'I0_dBm': -7,  # Initial intensity for this curve
        },
    }

    # Constant uncertainty on the raw Y measurements, drawn as error bars.
    # NOTE: it is not passed to curve_fit. A constant sigma would leave the
    # fitted parameters unchanged; only pcov would differ (with absolute_sigma=True).
    uncertainty_y = 15  # +/- 15 counts

    # Value (dB) subtracted from each nominal power to get the launch power shown in the legend
    POWER_OFFSET_DB = 15.45

    # Initial guesses (beta, alpha) per curve. A curve without an entry gets
    # p0=None, i.e. curve_fit's default start of 1 for every parameter, instead
    # of failing with NameError or silently reusing the previous curve's guesses.
    initial_guesses = {
        '13 dBm': [3.0e+01, 0.028],
        '3 dBm': [1.2e+02, 0.046],
        '-7 dBm': [1.4e+02, 0.041],
    }
    # Force beta >= 0 and alpha >= 0 for every curve.
    fit_bounds = ([0, 0], [np.inf, np.inf])

    # --- Single Plot with Logarithmic Scales ---
    plt.figure(figsize=(12, 7))

    # Colors and line styles for data points and fits
    colors = {'13 dBm': 'blue', '3 dBm': 'green', '-7 dBm': 'orange'}
    fit_line_styles = {'13 dBm': '-', '3 dBm': '--', '-7 dBm': ':'}

    # Points with SNR < 30 from all curves, highlighted together after the loop
    all_low_snr_x = []
    all_low_snr_y = []

    # Fit curves are collected here and drawn after the loop to control legend order
    fit_results_for_legend = []

    # Empty plot used only to put the model formula first in the legend
    plt.plot([], [], 'k--', label=r'Raman Model Function: $R(L) = \beta \cdot I_0 \cdot L \cdot e^{-\alpha L}$')

    for power_level, data in original_data_curves.items():
        original_x_km = data['x']
        original_y_counts_raw = data['y']  # Raw counts, used for the plotted data points
        original_snr = data['snr']
        I0_dBm = data['I0_dBm']
        I0_value_mW = 10**(I0_dBm / 10.0)  # dBm -> mW

        # Launch power shown in the legend
        modified_I0_dBm_for_legend = I0_dBm - POWER_OFFSET_DB

        print(f"\n--- Fitting for Power: {power_level} ---")
        print(f"Original initial intensity I0: {I0_dBm} dBm = {I0_value_mW:.3f} mW")
        print(f"Initial intensity I0 (for legend): {modified_I0_dBm_for_legend:.2f} dBm")

        # Subtract the fixed background for fitting only; clamp at 0.1 so points
        # at or below the background stay strictly positive.
        y_for_fit = np.maximum(original_y_counts_raw - FIXED_OFFSET_SUBTRACTION, 1e-1)

        # Only (beta, alpha) are fitted: the offset has already been subtracted.
        p0_initial_guesses = initial_guesses.get(power_level)

        try:
            # I0 is fixed per curve; the 3-argument lambda lets curve_fit infer
            # two free parameters even when p0 is None.
            popt, pcov = curve_fit(lambda L, beta, alpha: raman_model_no_offset(L, beta, alpha, I0_value_mW),
                                   original_x_km, y_for_fit, p0=p0_initial_guesses, bounds=fit_bounds)
        except (RuntimeError, ValueError) as e:
            # RuntimeError: no convergence; ValueError: e.g. infeasible p0 or non-finite data
            print(f"  Error in fitting for {power_level}: {e}")
            print("  You might need to adjust initial guesses (p0) or parameter bounds,")
            print("  or verify the quality of experimental data for this curve.")
        else:
            fitted_beta, fitted_alpha = popt
            # One-sigma standard errors from the diagonal of the covariance matrix
            beta_error, alpha_error = np.sqrt(np.diag(pcov))

            print(f"  Fitted parameters (after {FIXED_OFFSET_SUBTRACTION} Counts subtraction):")
            print(f"    Beta (β): {fitted_beta:.2e} Counts/(mW*km) +/- {beta_error:.2e}")
            print(f"    Alpha (α): {fitted_alpha:.4f} km^-1 +/- {alpha_error:.4f}")

            # Log-spaced lengths so the curve is smooth on the logarithmic x-axis
            L_fit = np.logspace(np.log10(original_x_km.min()), np.log10(original_x_km.max() * 1.5), 500)
            # Evaluate the offset-free model, then add the background back so the
            # curve is comparable with the raw data points.
            R_fit_model_only = raman_model_no_offset(L_fit, fitted_beta, fitted_alpha, I0_value_mW)
            R_fit_total = R_fit_model_only + FIXED_OFFSET_SUBTRACTION

            # Keep values positive for logarithmic display (only matters if the offset is negative)
            R_fit_total[R_fit_total <= 0] = 1e-1

            fit_results_for_legend.append({
                'L_fit': L_fit,
                'R_fit_total': R_fit_total,
                'color': colors[power_level],
                'style': fit_line_styles[power_level],
                'label': (f'Raman Model Fit (Launch Power: {modified_I0_dBm_for_legend:.2f} dBm)\n' +
                          r'$\beta={:.2e} \pm {:.2e}$ Counts/(mW$\cdot$km)'.format(fitted_beta, beta_error) + '\n' +
                          r'$\alpha={:.3f} \pm {:.3f}$ km$^{{-1}}$'.format(fitted_alpha, alpha_error)),
            })

        # Raw data points with the (raw-measurement) Y uncertainty as error bars
        plt.errorbar(original_x_km, original_y_counts_raw, yerr=uncertainty_y,
                     fmt='o', markersize=5,
                     markerfacecolor=colors[power_level],  # Fill color matches the curve color
                     markeredgecolor=colors[power_level],  # Edge color matches the curve color
                     ecolor=colors[power_level], elinewidth=2.5, capsize=4,
                     label=f'Original Data Points ({power_level}) with Uncertainty')

        # Collect low-SNR points at their raw-count positions
        low_snr_mask = original_snr < 30
        all_low_snr_x.extend(original_x_km[low_snr_mask])
        all_low_snr_y.extend(original_y_counts_raw[low_snr_mask])

    # Fitted curves go after the model-formula entry and data points in the legend
    for fit_data in fit_results_for_legend:
        plt.plot(fit_data['L_fit'], fit_data['R_fit_total'], fit_data['style'],
                 color=fit_data['color'], linewidth=1.5, label=fit_data['label'])

    # Highlight all low-SNR points with solid red circles
    if all_low_snr_x:
        plt.plot(all_low_snr_x, all_low_snr_y, 'o', color='red', markersize=6, markerfacecolor='red', markeredgecolor='red', markeredgewidth=1,
                 label='Points with SNR < 30')

    plt.xscale('log')
    plt.yscale('log')

    ax = plt.gca()

    # Fixed X-axis major ticks in km, printed as plain numbers. Minor-tick labels
    # are suppressed: on this sub-decade span the default log formatter would
    # otherwise label some minor ticks next to the fixed majors.
    x_major_ticks = np.array([50, 100, 150, 200, 300])
    ax.xaxis.set_major_locator(ticker.FixedLocator(x_major_ticks))
    ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
    ax.xaxis.set_minor_formatter(ticker.NullFormatter())

    # Y-axis major ticks at each decade, printed as plain numbers
    y_major_ticks = np.array([1, 10, 100, 1000, 10000])
    ax.yaxis.set_major_locator(ticker.FixedLocator(y_major_ticks))
    ax.yaxis.set_major_formatter(ticker.ScalarFormatter())

    # Unlabelled minor ticks for the Y-axis
    ax.yaxis.set_minor_locator(ticker.LogLocator(base=10.0, subs=np.arange(1.0, 10.0) * 0.1, numticks=10))
    ax.yaxis.set_minor_formatter(ticker.NullFormatter())

    plt.xlabel('Fiber Length (km)')
    plt.ylabel('Raman Power (Counts)')

    # The title intentionally does not mention the background subtraction
    plt.title('Raman Noise Investigation for Different Sensing Power Levels (Integration Time: 100 ms)')

    # Legend outside the plot, top right
    plt.legend(title='Curve Legend', bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0.)

    plt.grid(True, which="both", ls="-", alpha=0.6)
    # Leave the right 15% of the figure for the external legend
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    plt.show()

    # Print some results for verification
    print("\nOriginal Values for curve '13 dBm':")
    print("X values (all):", original_data_curves['13 dBm']['x'])
    print("Y values (all):", original_data_curves['13 dBm']['y'])